multidms.model_collection fitting pipeline

In the previous example notebook, we saw how to use the Data and Model classes to fit a single model and visualize its results. Here, we will see how to use the ModelCollection class and associated utilities to fit multiple models (in parallel, using multiprocessing) and to aggregate and compare the results across fits.

Two very common use cases for this interface include:

  1. Shrinkage analysis across a range of lasso coefficient strengths

  2. Training on distinct replicate training datasets

To give an example of each below, we use the multidms.fit_models function to obtain a collection of fits (in the form of a pandas.DataFrame object) spanning two replicate datasets and a range of lasso coefficient values. We then instantiate a multidms.ModelCollection object from these fits to aggregate and visualize the results.

Note

This module functionally wraps the Model interface for convenience. If you're training on CPUs and have more than one core on your machine, then this is definitely the way to go. Currently, the code doesn't do anything clever to optimize GPU usage when many models train in parallel (thar be dragons). If you would like to use GPUs for training, it is probably better to train each model individually using the fit_one_model function in this module.
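
For example, a minimal sketch of that sequential, one-model-at-a-time approach (here my_data is a hypothetical multidms.Data object and the lasso values are placeholders):

# fit models one at a time so that only one model occupies the GPU at once
fits = [
    multidms.model_collection.fit_one_model(
        dataset=my_data,
        scale_coeff_lasso_shift=lasso,
    )
    for lasso in [0.0, 1e-5, 1e-3]
]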

[1]:
# import notebook dependencies
import pandas as pd
import multidms
%matplotlib inline

Load functional scores

In the previous example, we showed data from two conditions and fit a single model to that data. Here, we'll load multiple replicates of that same data from three deep mutational scanning experiments on the Delta, Omicron BA.1, and Omicron BA.2 Spike proteins.

[2]:
# load scores, and fill wt values with empty strings
func_score_df = pd.read_csv("Delta_BA1_BA2_func_score_df.csv").fillna("")
# split condition and replicate
func_score_df = func_score_df.assign(
    replicate = func_score_df["condition"].apply(lambda x: x.split("-")[-1]),
    condition = func_score_df["condition"].apply(lambda x: "-".join(x.split("-")[:-1]))
)
func_score_df.sample(5)

[2]:
func_score aa_substitutions condition replicate
416375 -0.2593 G75S K1181N Omicron_BA.1 3
174249 -0.7858 H954N Omicron_BA.1 1
449294 -3.5000 L5A A292V G769D G838R F855L W1217L Omicron_BA.2 1
87355 -3.5000 T302A L1141F Delta 3
200762 -0.4590 A484G V622Q A942T Omicron_BA.1 1
[3]:
for condition, cfs in func_score_df.groupby('condition'):
    print(f"{condition} replicates:\n\t{cfs.replicate.unique()}\n")
Delta replicates:
        ['1' '2' '3' '4']

Omicron_BA.1 replicates:
        ['1' '2' '3']

Omicron_BA.2 replicates:
        ['1' '2']

Instantiate multidms.Data objects for fitting

We would like to create two replicate training datasets, each of which should consist of one replicate from each of the three experiments. For simplicity, we’ll group the three experiments deriving from replicate ‘1’ together, and similarly for replicate ‘2’ – keeping in mind there is no significance to the replicate names in this case.

We’ll instantiate the Data objects as we’ve done before, but this time we’ll create an independent Data object for each replicate. Keep in mind that when comparing across replicate datasets using the multidms.ModelCollection interface, it is best to keep the reference and non-reference conditions consistent among datasets.

[4]:
data_replicates = [
    multidms.Data(
        func_score_df.query("replicate == @rep"),
        alphabet = multidms.AAS_WITHSTOP_WITHGAP,
        collapse_identical_variants = "mean",
        reference = "Delta",
        verbose = False,
        nb_workers=4,
        name = f"Replicate {rep}"
    )
    for rep in ["1", "2"]
]
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1697050301.354487  631537 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Fit one model with multidms.fit_one_model

The model_collection module offers a simple interface to instantiate and fit Model objects. First, let's fit a single model to one of the Data replicates instantiated above. To do this, we simply need to define the model parameters.

[11]:
single_set_of_params = {
    "dataset": data_replicates[0], # only one replicate dataset
    "num_training_steps" : 1,
    "iterations_per_step": 5, # Small number of iterations for purposes of this example
    "scale_coeff_lasso_shift": 1e-5,
}

For a full list and descriptions of available hyperparameters, see:

help(multidms.model_collection.fit_one_model)

With these, we can now fit a single model:

[12]:
fit = multidms.model_collection.fit_one_model(**single_set_of_params)
fit
[12]:
model                        <multidms.model.Model object at 0x7fb64b9be450>
dataset_name                                                     Replicate 1
step_loss                                                [2.697774660104196]
epistatic_model                                                      Sigmoid
output_activation                                                   Identity
scale_coeff_lasso_shift                                              0.00001
scale_coeff_ridge_beta                                                     0
scale_coeff_ridge_shift                                                    0
scale_coeff_ridge_gamma                                                    0
scale_coeff_ridge_alpha_d                                                  0
huber_scale_huber                                                          1
gamma_corrected                                                        False
alpha_d                                                                False
init_beta_naught                                                         0.0
lock_beta_naught_at                                                     None
tol                                                                   0.0001
num_training_steps                                                         1
iterations_per_step                                                        5
n_hidden_units                                                             5
lower_bound                                                             None
PRNGKey                                                                    0
dtype: object
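
Since fit_one_model returns a pandas.Series, the individual entries shown above can be pulled out directly by name; for example (assuming step_loss is stored as a list, as the display suggests):

final_loss = fit["step_loss"][-1]  # loss recorded at the end of the last training step
model = fit["model"]               # the fitted multidms.Model object itself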

Now we have the Model object along with the associated hyperparameters that were used to fit the model to the replicate dataset. Let's take a look at the betas (\(\beta_m\)) from this fit using the Model.mut_param_heatmap method.

[13]:
fit.model.mut_param_heatmap(mut_param="beta")
[13]:

Next, we’ll see how to fit multiple models in parallel.

Fit multiple models (in parallel) with multidms.fit_models

Currently, the model_collection interface offers two public functions: fit_one_model, as we saw above, and fit_models. The latter wraps the former, allowing multiple models to be fit in parallel by spawning child processes using multiprocessing. The fit_models function takes a single dictionary which defines the parameter space of all models you wish to run. Each value in the dictionary must be a list of values, even in the case of singletons. The function computes all combinations of the parameter space and passes each combination to multidms.utils.fit_wrapper to be run in parallel; thus, only key-value pairs which match the fit_one_model kwargs are allowed.

To exemplify this, let’s again define the hyperparameters, but this time, we’ll specify each value as a list of values to be fit in parallel.

[ ]:
# define the parameter space; alpha_d is left at its default (no free alpha_d parameter)
collection_params = {
    "dataset": data_replicates,
    "num_training_steps" : [20],
    "iterations_per_step": [1000],
    "output_activation" : ["Softplus"],
    "lower_bound" : [-3.5],
    "scale_coeff_lasso_shift": [0.0, 1e-5, 1e-3],
}

Before we fit the models, let's take a look at what collection of models we're specifying with this dictionary by calling the "private" function multidms.model_collection._explode_params_dict. As implied by the leading underscore, this behavior is hidden from the user and is performed internally when calling fit_models.

[144]:
from pprint import pprint
pprint(multidms.model_collection._explode_params_dict(collection_params))
[{'alpha_d': False,
  'dataset': <multidms.data.Data object at 0x7fb5de1b8350>,
  'iterations_per_step': 1000,
  'lower_bound': -3.5,
  'num_training_steps': 20,
  'output_activation': 'Softplus',
  'scale_coeff_lasso_shift': 0.0},
 {'alpha_d': False,
  'dataset': <multidms.data.Data object at 0x7fb5de1b8350>,
  'iterations_per_step': 1000,
  'lower_bound': -3.5,
  'num_training_steps': 20,
  'output_activation': 'Softplus',
  'scale_coeff_lasso_shift': 1e-05},
 {'alpha_d': False,
  'dataset': <multidms.data.Data object at 0x7fb5de1b8350>,
  'iterations_per_step': 1000,
  'lower_bound': -3.5,
  'num_training_steps': 20,
  'output_activation': 'Softplus',
  'scale_coeff_lasso_shift': 0.001},
 {'alpha_d': False,
  'dataset': <multidms.data.Data object at 0x7fb5f79b6fd0>,
  'iterations_per_step': 1000,
  'lower_bound': -3.5,
  'num_training_steps': 20,
  'output_activation': 'Softplus',
  'scale_coeff_lasso_shift': 0.0},
 {'alpha_d': False,
  'dataset': <multidms.data.Data object at 0x7fb5f79b6fd0>,
  'iterations_per_step': 1000,
  'lower_bound': -3.5,
  'num_training_steps': 20,
  'output_activation': 'Softplus',
  'scale_coeff_lasso_shift': 1e-05},
 {'alpha_d': False,
  'dataset': <multidms.data.Data object at 0x7fb5f79b6fd0>,
  'iterations_per_step': 1000,
  'lower_bound': -3.5,
  'num_training_steps': 20,
  'output_activation': 'Softplus',
  'scale_coeff_lasso_shift': 0.001}]

What is produced is a list of **kwargs to pass to fit_one_model. In this case there are 6 total models to fit (2 replicate datasets x 3 lasso strengths). To fit these models, we simply pass the collection_params to fit_models and specify the number of threads available to run the model fits in parallel.

[10]:
n_fit, n_failed, fit_models = multidms.model_collection.fit_models(collection_params, n_threads=4)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1696946778.812740  280061 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1696946779.115679  280058 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1696946779.257901  280060 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1696946779.733650  280059 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[11]:
print(f"Of the 6 model fits, {n_fit} succeeded and {n_failed} failed")
Of the 6 model fits, 6 succeeded and 0 failed

The third object returned by fit_models is a pandas.DataFrame which contains the results from each model fit, created by stacking the pd.Series objects returned by fit_one_model.

[146]:
fit_models
[146]:
model dataset_name step_loss epistatic_model output_activation scale_coeff_lasso_shift scale_coeff_ridge_beta scale_coeff_ridge_shift scale_coeff_ridge_gamma scale_coeff_ridge_ch ... gamma_corrected alpha_d init_beta_naught lock_beta_naught_at tol num_training_steps iterations_per_step n_hidden_units lower_bound PRNGKey
0 <multidms.model.Model object at 0x7fb581be2550> Replicate 1 [1.7568787605449558, 1.6370501661172565, 1.604... Sigmoid Softplus 0.0 0 0 0 0 ... False False 0.0 None 0.0001 20 1000 5 -3.5 0
1 <multidms.model.Model object at 0x7fb581e8d690> Replicate 1 [1.7818608886109695, 1.2296774475140684, 0.979... Sigmoid Softplus 0.00001 0 0 0 0 ... False False 0.0 None 0.0001 20 1000 5 -3.5 0
2 <multidms.model.Model object at 0x7fb581e8f850> Replicate 1 [1.6018007367535427, 1.2163388879503787, 1.124... Sigmoid Softplus 0.001 0 0 0 0 ... False False 0.0 None 0.0001 20 1000 5 -3.5 0
3 <multidms.model.Model object at 0x7fb572c64390> Replicate 2 [1.1068374170919253, 0.8820706887917396, 0.807... Sigmoid Softplus 0.0 0 0 0 0 ... False False 0.0 None 0.0001 20 1000 5 -3.5 0
4 <multidms.model.Model object at 0x7fb581be0f90> Replicate 2 [1.1264708733781783, 0.9064847516719441, 0.834... Sigmoid Softplus 0.00001 0 0 0 0 ... False False 0.0 None 0.0001 20 1000 5 -3.5 0
5 <multidms.model.Model object at 0x7fb580645050> Replicate 2 [1.3794643479819475, 1.1491574484300826, 1.067... Sigmoid Softplus 0.001 0 0 0 0 ... False False 0.0 None 0.0001 20 1000 5 -3.5 0

6 rows × 21 columns

This DataFrame is all that's necessary to instantiate a multidms.ModelCollection object.

Note

If you wanted to use a pipeline to farm out the fitting processes independently, the same DataFrame could be acquired by collecting the individual Series objects returned by fit_one_model and concatenating them using the simple multidms.model_collection.stack_fit_models utility.
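
For example, a minimal sketch of that pattern (this assumes stack_fit_models accepts a list of the fit Series, as described above):

# expand the parameter dictionary into one kwargs dict per model
param_sets = multidms.model_collection._explode_params_dict(collection_params)
# in a real pipeline, each of these fits could be farmed out as an independent job
fits = [multidms.model_collection.fit_one_model(**params) for params in param_sets]
# combine the individual fit Series into the DataFrame expected by ModelCollection
fit_models_df = multidms.model_collection.stack_fit_models(fits)
mc = multidms.ModelCollection(fit_models_df)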

ModelCollection Object

The ModelCollection class is simply a nice interface for split-apply-combine operations on the model attributes, such as the mutations and variants_df dataframes, contained within the pandas.DataFrame object returned by fit_models. To instantiate a ModelCollection object, we simply pass that dataframe to the constructor.

[138]:
mc = multidms.ModelCollection(fit_models)

To get the raw data in a nice tidy format, use ModelCollection.split_apply_combine_muts, which has a straightforward name for a simple goal. This function applies the split-apply-combine paradigm to the collection of individual mutational effects tables (our example currently has 6) while keeping the fit hyperparameters of interest tied to the data.

[130]:
combined_lasso_strengths = mc.split_apply_combine_muts(
    groupby=["dataset_name", "scale_coeff_lasso_shift"]
)
combined_lasso_strengths.head()
[130]:
mutation beta shift_Omicron_BA.1 shift_Omicron_BA.2 predicted_func_score_Delta predicted_func_score_Omicron_BA.1 predicted_func_score_Omicron_BA.2 times_seen_Delta times_seen_Omicron_BA.1 times_seen_Omicron_BA.2
dataset_name scale_coeff_lasso_shift
Replicate 1 0.0 A1015D 0.121804 0.000000 -0.013379 0.047931 0.098184 2.298876 2.0 0.0 3.0
0.0 A1015Q -0.580047 -0.014002 0.000000 -0.281579 -0.239191 1.974444 0.0 1.0 0.0
0.0 A1015S 2.255091 -2.769225 -0.027141 0.437377 -0.194953 2.691718 5.0 10.0 29.0
0.0 A1015T -1.023087 -3.875584 -0.133883 -0.551092 -2.116062 1.616608 8.0 8.0 22.0
0.0 A1015V -2.074877 -4.159578 -0.046739 -1.259227 -2.181602 0.967254 5.0 7.0 7.0

The fit collection groupby features (scale_coeff_lasso_shift and dataset_name in this case) are set as a multiindex; the index then easily distinguishes fit groups from mutation features, and is more memory efficient. If groupby=None (the default), then we group by all available fit attributes. Also note that by default, only mutations shared by all datasets are returned, but this can be changed by setting inner_merge_dataset_muts=False.
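
For example, a minimal sketch of relaxing the inner merge and flattening the index (the parameter name comes from the paragraph above):

all_muts = mc.split_apply_combine_muts(
    groupby=["dataset_name", "scale_coeff_lasso_shift"],
    inner_merge_dataset_muts=False,  # keep mutations not shared by every dataset
)
all_muts_flat = all_muts.reset_index()  # move the fit groupby features back into columns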

Mutational parameter heatmaps

Just as you might use Model.mut_param_heatmap to visualize the mutation effects from a single model, you can use ModelCollection.mut_param_heatmap to visualize the aggregated mutation effects from a collection of models fit to multiple replicate datasets.

Using all defaults this would be called as follows:

heatmap_chart = mc.mut_param_heatmap()

However, our current example fit collection has 3 different lasso strengths, which it doesn't make sense to aggregate over. Thus, this call will result in:

ValueError: invalid query, more than one unique hyper-parameter besides dataset_name

To fix this, we must subset our model collection such that we are only aggregating across the different training datasets.

[135]:
chart = mc.mut_param_heatmap(
    query="scale_coeff_lasso_shift == 1e-5",
    mut_param="beta"
)
chart
[135]:

Here, we visualized the models' beta (\(\beta_m\)) parameters, but we can also visualize the respective shift parameters for each non-reference condition (\(\Delta_{d,m}\)) by setting mut_param='shift'.

[136]:
chart = mc.mut_param_heatmap(
    query="scale_coeff_lasso_shift == 1e-5",
    mut_param="shift" # or "shift"
)
chart
[136]:

Or, we can visualize the mutation predictions (\(\hat{y}_{m, d}\)), noting that by default we view the predictions with phenotype as effect (the difference from the non-zero wildtype prediction).

[137]:
chart = mc.mut_param_heatmap(
    query="scale_coeff_lasso_shift == 1e-5",
    mut_param="predicted_func_score",
    phenotype_as_effect=True
)
chart
[137]:

Trace charts for mutational shrinkage

Another common reason you might fit a collection of models is to test multiple lasso strength coefficients. When you have a few mutations of interest, you might want to see how the lasso strength affects the shrinkage of the mutation effects. To do this, we can use the ModelCollection.mut_param_traceplot method.

Begin by selecting a subset of mutations to visualize. Here, we'll select the top 10 mutations by absolute value of the beta parameter across all fits.

[142]:
combined_lasso_strengths["abs_beta"] = combined_lasso_strengths["beta"].abs()
muts_of_interest = combined_lasso_strengths.sort_values("abs_beta", ascending=False).head(10).mutation.values
mc.mut_param_traceplot(mutations = muts_of_interest, mut_param="shift")
[142]: